{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp models.RNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RNN\n",
    "\n",
    "> These are RNN, LSTM and GRU PyTorch implementations created by Ignacio Oguiza - oguiza@gmail.com."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from tsai.imports import *\n",
    "from tsai.models.layers import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "# example\n",
    "# rnn = nn.LSTM(10, 20, 2)\n",
    "# input = torch.randn(5, 3, 10)\n",
    "# h0 = torch.randn(2, 3, 20)\n",
    "# c0 = torch.randn(2, 3, 20)\n",
    "# output, (hn, cn) = rnn(input, (h0, c0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class _RNN_Base(Module):\n",
    "    \"\"\"Shared implementation behind `RNN`, `LSTM` and `GRU`: a recurrent encoder followed by a linear head.\n",
    "\n",
    "    Subclasses choose the recurrent layer by setting the `_cell` class attribute.\n",
    "\n",
    "    Args:\n",
    "        c_in: number of input variables per time step (input size of the recurrent layer).\n",
    "        c_out: number of output features produced by the final linear layer.\n",
    "        hidden_size: hidden size of the recurrent layer.\n",
    "        n_layers: number of stacked recurrent layers (forwarded as `num_layers`).\n",
    "        bias: forwarded to the recurrent layer only; the linear head keeps its default.\n",
    "        rnn_dropout: forwarded as the recurrent layer's `dropout`.\n",
    "        bidirectional: forwarded to the recurrent layer; also doubles the input width of the head.\n",
    "        fc_dropout: dropout rate applied to the features fed to the head; `noop` is used when falsy.\n",
    "    \"\"\"\n",
    "    def __init__(self, c_in, c_out, hidden_size=100, n_layers=1, bias=True, rnn_dropout=0, bidirectional=False, fc_dropout=0.):\n",
    "        # batch_first=True: the recurrent layer consumes [batch_size x seq_len x n_vars]\n",
    "        self.rnn = self._cell(c_in, hidden_size, num_layers=n_layers, bias=bias, batch_first=True,\n",
    "                              dropout=rnn_dropout, bidirectional=bidirectional)\n",
    "        self.dropout = noop if not fc_dropout else nn.Dropout(fc_dropout)\n",
    "        # a bidirectional layer emits forward and backward features side by side\n",
    "        n_directions = 1 + bidirectional\n",
    "        self.fc = nn.Linear(hidden_size * n_directions, c_out)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map `x` of shape [batch_size x n_vars x seq_len] to an output of shape [batch_size x c_out].\"\"\"\n",
    "        seq_first = x.transpose(2,1)         # [batch_size x n_vars x seq_len] --> [batch_size x seq_len x n_vars]\n",
    "        all_steps, _ = self.rnn(seq_first)   # every sequence step: [batch_size x seq_len x hidden_size * (1 + bidirectional)]\n",
    "        last_step = all_steps[:, -1]         # last sequence step only: [batch_size x hidden_size * (1 + bidirectional)]\n",
    "        return self.fc(self.dropout(last_step))\n",
    "        \n",
    "class RNN(_RNN_Base):\n",
    "    \"Sequence model whose recurrent layer is `nn.RNN`; arguments are those of `_RNN_Base.__init__`.\"\n",
    "    _cell = nn.RNN\n",
    "    \n",
    "class LSTM(_RNN_Base):\n",
    "    \"Sequence model whose recurrent layer is `nn.LSTM`; arguments are those of `_RNN_Base.__init__`.\"\n",
    "    _cell = nn.LSTM\n",
    "    \n",
    "class GRU(_RNN_Base):\n",
    "    \"Sequence model whose recurrent layer is `nn.GRU`; arguments are those of `_RNN_Base.__init__`.\"\n",
    "    _cell = nn.GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 16\n",
    "c_in = 3\n",
    "seq_len = 12\n",
    "c_out = 2\n",
    "xb = torch.rand(bs, c_in, seq_len)\n",
    "\n",
    "# Every architecture must map [bs x c_in x seq_len] to [bs x c_out], both with the default\n",
    "# settings and with a multi-layer, bidirectional configuration that uses dropout.\n",
    "custom_kwargs = dict(hidden_size=100, n_layers=2, bias=True, rnn_dropout=0.2, bidirectional=True, fc_dropout=0.5)\n",
    "for arch in (RNN, LSTM, GRU):\n",
    "    test_eq(arch(c_in, c_out)(xb).shape, [bs, c_out])\n",
    "    test_eq(arch(c_in, c_out, **custom_kwargs)(xb).shape, [bs, c_out])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.809702</td>\n",
       "      <td>1.783900</td>\n",
       "      <td>0.188889</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tsai.data.all import *\n",
    "dsid = 'NATOPS'\n",
    "bs = 16\n",
    "# With return_split=False the call is unpacked as data, labels and split indices\n",
    "X, y, splits = get_UCR_data(dsid, return_split=False)\n",
    "# presumably one entry per element of (X, y): nothing for the inputs, Categorize for the labels - confirm\n",
    "tfms = [None, [Categorize()]]\n",
    "dsets = TSDatasets(X, y, tfms=tfms, splits=splits)\n",
    "dls = TSDataLoaders.from_dsets(\n",
    "    dsets.train, dsets.valid,\n",
    "    bs=bs, num_workers=0, shuffle=False,\n",
    ")\n",
    "# dls.vars and dls.c are passed as the model's c_in and c_out\n",
    "model = LSTM(dls.vars, dls.c)\n",
    "learn = Learner(dls, model, metrics=accuracy)\n",
    "learn.fit_one_cycle(1, 3e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RNN(\n",
      "  (rnn): RNN(3, 100, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)\n",
      "  (dropout): Dropout(p=0.5, inplace=False)\n",
      "  (fc): Linear(in_features=200, out_features=2, bias=True)\n",
      ")\n",
      "(81802, True)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 2])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2-layer bidirectional RNN with dropout in both the recurrent stack and the head\n",
    "m = RNN(\n",
    "    c_in, c_out,\n",
    "    hidden_size=100, n_layers=2, bidirectional=True,\n",
    "    rnn_dropout=.5, fc_dropout=.5,\n",
    ")\n",
    "print(m)\n",
    "print(total_params(m))\n",
    "m(xb).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LSTM(\n",
      "  (rnn): LSTM(3, 100, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)\n",
      "  (dropout): Dropout(p=0.5, inplace=False)\n",
      "  (fc): Linear(in_features=200, out_features=2, bias=True)\n",
      ")\n",
      "(326002, True)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 2])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2-layer bidirectional LSTM with dropout in both the recurrent stack and the head\n",
    "m = LSTM(\n",
    "    c_in, c_out,\n",
    "    hidden_size=100, n_layers=2, bidirectional=True,\n",
    "    rnn_dropout=.5, fc_dropout=.5,\n",
    ")\n",
    "print(m)\n",
    "print(total_params(m))\n",
    "m(xb).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GRU(\n",
      "  (rnn): GRU(3, 100, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)\n",
      "  (dropout): Dropout(p=0.5, inplace=False)\n",
      "  (fc): Linear(in_features=200, out_features=2, bias=True)\n",
      ")\n",
      "(244602, True)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 2])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2-layer bidirectional GRU with dropout in both the recurrent stack and the head\n",
    "m = GRU(\n",
    "    c_in, c_out,\n",
    "    hidden_size=100, n_layers=2, bidirectional=True,\n",
    "    rnn_dropout=.5, fc_dropout=.5,\n",
    ")\n",
    "print(m)\n",
    "print(total_params(m))\n",
    "m(xb).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "# NOTE(review): create_scripts and beep come from the star imports; presumably this exports the\n",
    "# notebooks to library modules and then signals the result - confirm against tsai.imports.\n",
    "out = create_scripts()\n",
    "beep(out)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
